This notebook aims to reproduce Experiment 2 / Figure 7 in the paper.
%load_ext autoreload
%autoreload 2
import numpy as np
import tensorflow as tf
from copy import deepcopy
import plotly.graph_objs as go
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
# Set up plotly
init_notebook_mode(connected=True)
layout = go.Layout(
width=700,
height=500,
margin=go.Margin(l=60, r=60, b=40, t=20),
showlegend=False
)
config={'showLink': False}
colorscale =[[0.0, '#FF881E'], [1.0, '#4E73AE']]
# Make results completely repeatable
seed = 0
np.random.seed(seed)
tf.set_random_seed(seed)
We build the VAE following the implementation details in Appendix D of the paper.
# Build the VAE (Appendix D): a shared tanh encoder trunk with separate
# heads for the posterior mean and variance, a sigmoid decoder for the
# reconstruction mean, and an RBF network for the decoder variance.
from tensorflow.python.keras import Sequential, Model
from tensorflow.python.keras.layers import Dense, Input, Lambda
from src.vae import VAE
from src.rbf import RBFLayer
# Implementation details from Appendix D
input_dim = 784   # flattened 28x28 MNIST digits
latent_dim = 2    # 2-D latent space, so it can be plotted directly
l2_reg = tf.keras.regularizers.l2(1e-5)
# Create the encoder models.  The same `enc_shared` layer instance is placed
# in both Sequentials, so the mean and variance heads share its weights.
enc_input = Input((input_dim,))
enc_shared = Dense(64, activation='tanh', kernel_regularizer=l2_reg)
enc_mean = Sequential([
enc_shared,
Dense(32, activation='tanh', kernel_regularizer=l2_reg),
Dense(latent_dim, activation='linear', kernel_regularizer=l2_reg)
])
# softplus keeps the predicted posterior variance strictly positive
enc_var = Sequential([
enc_shared,
Dense(32, activation='tanh', kernel_regularizer=l2_reg),
Dense(latent_dim, activation='softplus', kernel_regularizer=l2_reg)
])
# Rebind both heads as Models on a single shared input tensor
enc_mean = Model(enc_input, enc_mean(enc_input))
enc_var = Model(enc_input, enc_var(enc_input))
# Create the decoder models
dec_input = Input((latent_dim,))
dec_mean = Sequential([
Dense(32, activation='tanh', kernel_regularizer=l2_reg),
Dense(64, activation='tanh', kernel_regularizer=l2_reg),
Dense(input_dim, activation='sigmoid', kernel_regularizer=l2_reg)
])
dec_mean = Model(dec_input, dec_mean(dec_input))
# Build the RBF network that models the decoder variance
num_centers = 64   # number of RBF centers; also the k-means cluster count below
a = 2.0            # bandwidth scaling factor, used when fitting the RBF later
rbf = RBFLayer([input_dim], num_centers)
dec_var = Model(dec_input, rbf(dec_input))
vae = VAE(enc_mean, enc_var, dec_mean, dec_var, dec_stddev=1.)
from tensorflow.python.keras.datasets import mnist
# Train the VAE on the MNIST digits 0 and 1 only
(x_train_all, y_train_all), _ = mnist.load_data()
# Vectorized class filter: a boolean mask over the labels replaces the
# original per-sample Python loop + zip/unzip (same samples, same order).
mask = np.isin(y_train_all, [0, 1])
x_train = x_train_all[mask].astype('float32') / 255.  # scale pixels to [0, 1]
x_train = x_train.reshape((len(x_train), -1))         # flatten 28x28 -> 784
y_train = y_train_all[mask]
# Shuffle the data (np.random was seeded above, so this is repeatable)
p = np.random.permutation(len(x_train))
x_train = x_train[p]
y_train = y_train[p]
First we train the VAE without training the generator's variance network; the variance network will be trained separately later.
# Fit the VAE.  The variance network is trained separately later — see the
# recompile_for_var_training step below.
history = vae.model.fit(x_train,
epochs=50,
batch_size=32,
validation_split=0.1,
verbose=0)
# Plot the losses
data = [go.Scatter(y=history.history['loss'], name='Train Loss'),
go.Scatter(y=history.history['val_loss'], name='Validation Loss')]
# Deep-copy the shared layout so axis titles don't leak into later figures
plot_layout = deepcopy(layout)
plot_layout['xaxis'] = {'title': 'Epoch'}
plot_layout['yaxis'] = {'title': 'NELBO'}
plot_layout['showlegend'] = True
iplot(go.Figure(data=data, layout=plot_layout), config=config)
# Display a 2D plot of the classes in the latent space
# NOTE(review): the encoder appears to return (sampled z, posterior mean,
# posterior variance) in that order — confirm against src.vae.VAE.
sampled, encoded_mean, encoded_var = vae.encoder.predict(x_train)
# Plot the first 300 posterior means, colored by digit class
scatter_plot = go.Scatter(
x = encoded_mean[:300, 0],
y = encoded_mean[:300, 1],
mode = 'markers',
marker = {'color': y_train[:300], 'colorscale': colorscale}
)
data = [scatter_plot]
iplot(go.Figure(data=data, layout=layout), config=config)
To train the RBF variance network, we first have to find the centers of the latent points.
from sklearn.cluster import KMeans
# Find the centers of the latent representations; these become the RBF
# centers.  A fixed random_state keeps the clustering repeatable.
kmeans_model = KMeans(n_clusters=num_centers, random_state=0)
kmeans_model = kmeans_model.fit(encoded_mean)
centers = kmeans_model.cluster_centers_
# Visualize the centers on top of the latent scatter
center_plot = go.Scatter(
x = centers[:, 0],
y = centers[:, 1],
mode = 'markers',
marker = {'color': 'red'}
)
data = [scatter_plot, center_plot]
iplot(go.Figure(data=data, layout=layout), config=config)
# Group every latent mean under its nearest k-means center
clustering = {c: [] for c in range(num_centers)}
for z, c in zip(encoded_mean, kmeans_model.predict(encoded_mean)):
    clustering[c].append(z)
# One bandwidth per center: 0.5 / (a * mean member distance)^2, falling back
# to 0 for any center that ended up with no assigned points.
bandwidths = []
for c, members in clustering.items():
    if members:
        mean_dist = np.linalg.norm(np.asarray(members) - centers[c], axis=1).mean()
        bandwidths.append(0.5 / (a * mean_dist) ** 2)
    else:
        bandwidths.append(0)
bandwidths = np.array(bandwidths)
# Train the RBF variance network.
# NOTE(review): recompile_for_var_training presumably freezes the other VAE
# weights so only the RBF is updated — confirm in src.vae.
vae.recompile_for_var_training()
# Keep the layer's learned kernel but plug in the k-means centers and the
# bandwidths computed above; the list order [kernel, centers, bandwidths]
# must match RBFLayer's internal weight ordering.
rbf_kernel = rbf.get_weights()[0]
rbf.set_weights([rbf_kernel, centers, bandwidths])
history = vae.model.fit(x_train,
epochs=100,
batch_size=32,
validation_split=0.1,
verbose=0)
# Plot the losses
data = [go.Scatter(y=history.history['loss'],
name='Train Loss'),
go.Scatter(y=history.history['val_loss'],
name='Validation Loss')]
# Deep-copy the shared layout so axis titles don't leak into later figures
plot_layout = deepcopy(layout)
plot_layout['xaxis'] = {'title': 'Epoch'}
plot_layout['yaxis'] = {'title': 'NELBO'}
plot_layout['showlegend'] = True
iplot(go.Figure(data=data, layout=plot_layout), config=config)
Next, we pick two latent points to serve as the start and end points for finding a geodesic.
# Use two of the encoded digits as the geodesic endpoints
z_start, z_end = encoded_mean[[5, 26]]
# Visualize the endpoints on top of the latent scatter.
# zip(z_start, z_end) transposes the two 2-D points into per-axis pairs.
task_x, task_y = zip(z_start, z_end)
task_plot = go.Scatter(
    x=list(task_x),
    y=list(task_y),
    mode='markers',
    marker={'color': 'red'}
)
data = [scatter_plot, task_plot]
iplot(go.Figure(data=data, layout=layout), config=config)
from src.util import wrap_model_in_float64
# Get the mean and std predictors from the trained decoder.
# NOTE(review): vae.decoder.output appears to be (sample, mean, variance),
# mirroring the encoder — confirm against src.vae.
_, mean_output, var_output = vae.decoder.output
sqrt_layer = Lambda(tf.sqrt)  # std = sqrt(variance)
dec_mean = Model(vae.decoder.input, mean_output)
dec_std = Model(vae.decoder.input, sqrt_layer(var_output))
# Wrap both predictors in float64 — presumably for numerical stability of
# the geodesic computations below; see src.util for details.
dec_mean = wrap_model_in_float64(dec_mean)
dec_std = wrap_model_in_float64(dec_std)
session = tf.keras.backend.get_session()
from src.plot import plot_magnification_factor
# Magnification-factor heatmap over a 200x200 grid on [-4, 4]^2 (log scale),
# with the geodesic endpoints overlaid
heatmap_z1 = np.linspace(-4, 4, 200)
heatmap_z2 = np.linspace(-4, 4, 200)
heatmap = plot_magnification_factor(session,
heatmap_z1,
heatmap_z2,
dec_mean,
dec_std,
additional_data=[task_plot],
layout=layout,
log_scale=True)
%%time
from src.discrete import find_geodesic_discrete
# Discrete geodesic between the two endpoints: a 50-node curve optimized
# for at most 400 gradient steps
curve, iterations = find_geodesic_discrete(session,
z_start, z_end,
dec_mean,
std_generator=dec_std,
num_nodes=50,
max_steps=400,
learning_rate=0.01)
print('-' * 20)
from src.plot import plot_latent_curve_iterations
# Show every 10th optimization iterate over the heatmap + latent scatter
plot_latent_curve_iterations(iterations, [heatmap, scatter_plot], layout,
step_size=10)
%%time
from src.geodesic import find_geodesic
# Continuous geodesic solver between the same endpoints, starting from 20
# curve nodes and refining up to 1000
result, iterations = find_geodesic(session, z_start, z_end,
dec_mean, std_generator=dec_std,
initial_nodes=20, max_nodes=1000)
print('-' * 20)
from src.plot import plot_latent_curve_iterations
# Plot the solver's iterates over the latent scatter
plot_latent_curve_iterations(iterations, [scatter_plot], layout)
# fun_jac is not feasible
# (The markdown heading above was fused onto the code line by the notebook
# export; restored here as a comment so the cell is valid Python.)
# WARNING: with use_fun_jac=True this call was observed to run for hours and
# consume ~38 GB of RAM without finishing — see the note below. Run at your
# own risk.
result, iterations = find_geodesic(session, z_start, z_end,
dec_mean, std_generator=dec_std,
initial_nodes=20, max_nodes=1000,
use_fun_jac=True)
I had to kill the cell above after a couple of hours of taking 38GB of RAM.